@akroviakov akroviakov commented Oct 16, 2025

This PR is the first, minimally working, and somewhat crude application of the xegpu::uArch infra in the uarch-sensitive parts of XeGPU. We completely remove XeGPUTargetInfo.h and rely on the target attached to the GPU module.

This PR adds support for setting a default inst_data for store_nd, prefetch_nd, dpas, scatter, and gather. See propagate-layout-inst-data.mlir.

Some points to consider:

  • Type verification (XeGPUDialect.cpp, TensorDescType::verify) gets trickier, because there is no way to get the target attribute. I'd be glad to hear some feedback on it. For now, I use a constexpr placeholder.
  • All anchor ops need uArch, at least to query the subgroup size for getDefaultSIMTLayoutInfo. Inst_data is not part of getDefaultSIMTLayoutInfo, because it requires querying a specific operation in uArch.

llvmbot commented Oct 16, 2025

@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir

Author: Artem Kroviakov (akroviakov)

Changes

This PR is the first, minimally working, and somewhat crude application of the xegpu::uArch infra in the uarch-sensitive parts of XeGPU. We completely remove XeGPUTargetInfo.h and rely on the target attached to the GPU module. Due to the early uncertainty of the design, I only consider pvc, and all of the new instructions provide only a minimal interface.

This PR adds support for setting a default inst_data for store_nd, prefetch_nd, dpas, scatter, and gather.

Some points to consider:

  • LLVM does not use C++ RTTI; its own dynamic polymorphism requires manually amending the types for dyn_cast to work. This becomes crucial if you want to check for or get a specific uArch. An example can be found in requireTranspose() in XeGPUSubgroupDistribute.cpp, where we still need to check for hardcoded strings instead of using isa<>/dyn_cast<>.
  • Should uArch be exposed via a shared pointer to a constant structure or as a reference to a constant static structure? It depends on whether we allow for some fallback (the chip string is not present or no uArch is found for it, i.e., if(!uArch){...}) or strictly require a valid uArch (then an invalid uArch is not even possible: llvm_unreachable in getUArch()). Generally, I lean towards trying a reference to a static constant for simplicity, but uArch has been exposed to too few use cases to judge yet. For now, the shared_ptr version remains as the most flexible one.
  • Type verification (XeGPUDialect.cpp, TensorDescType::verify) gets trickier, because there is no way to get the target attribute. I'd be glad to hear some feedback on it. For now, I use a constexpr placeholder.
  • All anchor ops need uArch, at least to query the subgroup size for getDefaultSIMTLayoutInfo. Inst_data is not part of getDefaultSIMTLayoutInfo, because it requires querying a specific operation in uArch.
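To make the RTTI point concrete, here is a minimal, standalone sketch of the kind-discriminator plus `classof` convention that LLVM-style casting relies on. It does not use LLVM headers: `UArchKind`, `UArchSketch`, `PVCSketch`, `BMGSketch`, `isaKind`, and `dynCastKind` are hypothetical names invented for illustration, and the uArch base class in this PR has no such kind field.

```cpp
#include <cassert>

// Hypothetical discriminator; not part of the uArch classes in this PR.
enum class UArchKind { PVC, BMG };

class UArchSketch {
public:
  explicit UArchSketch(UArchKind kind) : kind(kind) {}
  virtual ~UArchSketch() = default;
  UArchKind getKind() const { return kind; }

private:
  const UArchKind kind;
};

struct PVCSketch : UArchSketch {
  PVCSketch() : UArchSketch(UArchKind::PVC) {}
  // The hook that kind-based casting consults instead of C++ RTTI.
  static bool classof(const UArchSketch *u) {
    return u->getKind() == UArchKind::PVC;
  }
};

struct BMGSketch : UArchSketch {
  BMGSketch() : UArchSketch(UArchKind::BMG) {}
  static bool classof(const UArchSketch *u) {
    return u->getKind() == UArchKind::BMG;
  }
};

// Hand-rolled stand-ins for isa<>/dyn_cast<>, kept local so the sketch
// compiles without LLVM.
template <typename To> bool isaKind(const UArchSketch &u) {
  return To::classof(&u);
}

template <typename To> const To *dynCastKind(const UArchSketch &u) {
  return To::classof(&u) ? static_cast<const To *>(&u) : nullptr;
}
```

With something along these lines, a check such as the one in requireTranspose() could test the type of the uArch object rather than compare chip strings, at the cost of maintaining the kind enum by hand.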

I am eager to gather feedback both for the uArch API and usage in general, so feel free to ask and propose changes.

My impression is that uArch is a nice fit for our purposes, but from the initial experience it currently appears a bit bloated (two std maps of shared pointers to instructions per uArch instance for seemingly compile-time constant data) and tricky to use (e.g., no support for LLVM polymorphism). I may have missed some justifications for this in the uArch PR, though, so it would be good to reiterate them here.
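For comparison, a rough sketch of the "constant static structure" alternative: the per-architecture numbers live in one immutable object and lookups hand out a pointer to it, with nullptr as the fallback for an unknown chip. `UArchConstants` and `lookupUArchConstants` are hypothetical names; the values are the ones Xe2Plus returns in this patch, and the sketch ignores instructions, register files, and caches entirely.

```cpp
#include <cassert>
#include <string>

// Hypothetical plain-data view of the constants Xe2Plus exposes in this patch.
struct UArchConstants {
  int subgroupSize;
  int packedFormatBitSizeGatherScatter;
  int packedFormatBitSize;
  int packedFormatBitSizeDpasB;
};

// Returns a pointer to immutable static data, or nullptr for an unknown chip.
// No heap allocation and no shared ownership are involved.
inline const UArchConstants *lookupUArchConstants(const std::string &chip) {
  // pvc and bmg both derive from Xe2Plus in the patch, so they share values.
  static constexpr UArchConstants xe2Plus{16, 32, 16, 32};
  if (chip == "pvc" || chip == "bmg")
    return &xe2Plus;
  return nullptr;
}
```

Whether this scales once instructions carry richer per-architecture queries is exactly the open question; the sketch only illustrates the ownership model.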


Patch is 57.33 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/163801.diff

11 Files Affected:

  • (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td (+23-11)
  • (removed) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h (-30)
  • (modified) mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td (+6-1)
  • (modified) mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h (+74-4)
  • (modified) mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h (+15-2)
  • (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp (+9-7)
  • (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp (+164-62)
  • (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp (+16-10)
  • (modified) mlir/test/Dialect/XeGPU/move-gpu-func-to-warp-op.mlir (+1-1)
  • (added) mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir (+51)
  • (modified) mlir/test/Dialect/XeGPU/propagate-layout.mlir (+59-23)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 5695d5d515d7f..ec236d702de0d 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -379,29 +379,41 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
   );
 
   let builders = [
-    AttrBuilder<(ins "llvm::ArrayRef<int32_t>": $lane_layout,
+    AttrBuilder<(ins "llvm::ArrayRef<int32_t>": $inst_data,
+                      "llvm::ArrayRef<int32_t>": $lane_layout,
                      "llvm::ArrayRef<int32_t>": $lane_data),
       [{
         auto sg_layout = DenseI32ArrayAttr();
         auto sg_data = DenseI32ArrayAttr();
-        auto inst_data = DenseI32ArrayAttr();
         auto order = DenseI32ArrayAttr();
-        return $_get($_ctxt, sg_layout, sg_data, inst_data,
+        return $_get($_ctxt, sg_layout, sg_data,
+                     DenseI32ArrayAttr::get($_ctxt, inst_data),
                      DenseI32ArrayAttr::get($_ctxt, lane_layout),
                      DenseI32ArrayAttr::get($_ctxt, lane_data), order);
       }]>,
     AttrBuilder<(ins "llvm::ArrayRef<int32_t>": $lane_layout,
-                     "llvm::ArrayRef<int32_t>": $lane_data,
-                     "llvm::ArrayRef<int32_t>": $order),
+                     "llvm::ArrayRef<int32_t>": $lane_data),
       [{
-        return $_get($_ctxt,
-                     /*sg_layout =*/ nullptr,
-                     /*sg_data   =*/ nullptr,
-                     /*inst_data =*/ nullptr,
+        auto sg_layout = DenseI32ArrayAttr();
+        auto sg_data = DenseI32ArrayAttr();
+        auto inst_data = DenseI32ArrayAttr();
+        auto order = DenseI32ArrayAttr();
+        return $_get($_ctxt, sg_layout, sg_data, inst_data,
                      DenseI32ArrayAttr::get($_ctxt, lane_layout),
-                     DenseI32ArrayAttr::get($_ctxt, lane_data),
-                     DenseI32ArrayAttr::get($_ctxt, order));
+                     DenseI32ArrayAttr::get($_ctxt, lane_data), order);
       }]>,
+    // AttrBuilder<(ins "llvm::ArrayRef<int32_t>": $lane_layout,
+    //                  "llvm::ArrayRef<int32_t>": $lane_data,
+    //                  "llvm::ArrayRef<int32_t>": $order),
+    //   [{
+    //     return $_get($_ctxt,
+    //                  /*sg_layout =*/ nullptr,
+    //                  /*sg_data   =*/ nullptr,
+    //                  /*inst_data =*/ nullptr,
+    //                  DenseI32ArrayAttr::get($_ctxt, lane_layout),
+    //                  DenseI32ArrayAttr::get($_ctxt, lane_data),
+    //                  DenseI32ArrayAttr::get($_ctxt, order));
+    //   }]>,
     AttrBuilder<(ins "DenseI32ArrayAttr": $lane_layout,
                      "DenseI32ArrayAttr": $lane_data,
                      "DenseI32ArrayAttr": $order),
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h
deleted file mode 100644
index 8aa9536cb67c1..0000000000000
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h
+++ /dev/null
@@ -1,30 +0,0 @@
-//===- XeGPUTargetInfo.h - Target constants ---------------------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_DIALECT_XEGPU_IR_XEGPUTARGETINFO_H_
-#define MLIR_DIALECT_XEGPU_IR_XEGPUTARGETINFO_H_
-
-namespace mlir {
-namespace xegpu {
-/// HW dependent constants.
-/// TODO: These constants should be queried from the target information.
-namespace targetinfo {
-constexpr unsigned subgroupSize = 16; // How many lanes in a subgroup.
-/// If DPAS A or B operands have low precision element types they must be packed
-/// according to the following sizes.
-constexpr unsigned packedSizeInBitsForDefault =
-    16; // Minimum packing size per register for DPAS A.
-constexpr unsigned packedSizeInBitsForDpasB =
-    32; // Minimum packing size per register for DPAS B.
-constexpr unsigned packedSizeInBitsForGatherScatter =
-    32; // Minimum packing size per register for Gather and Scatter ops.
-} // namespace targetinfo
-} // namespace xegpu
-} // namespace mlir
-
-#endif // MLIR_DIALECT_XEGPU_IR_XEGPUTARGETINFO_H_
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
index 564d9c4d5422b..5ef1d499d618f 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
@@ -43,7 +43,12 @@ def XeGPUPropagateLayout : Pass<"xegpu-propagate-layout"> {
   let options = [Option<
     "printOnly", "print-analysis-only", "bool",
     /*default=*/"false",
-    "Print the result of layout propagation analysis and exit.">];
+    "Print the result of layout propagation analysis and exit.">,
+    Option<
+    "assumeUnrolled", "assume-unrolled", "bool",
+    /*default=*/"false",
+    "If the input IR has SG-sized tiles matching instruction sizes, omit `inst_data`.">
+  ];
 }
 
 def XeGPUWgToSgDistribute : Pass<"xegpu-wg-to-sg-distribute"> {
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
index 0519f7b2e277d..5cb6d61336391 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
@@ -42,12 +42,59 @@ struct Xe2Plus : public uArch {
               &instrs = {})
       : uArch(archName, archDescription, regInfo, cacheInfo, instrs),
         xeCore(xeCore) {}
+  int getSubgroupSize() const override { return 16; }
+  int getPackedFormatBitSizeGatherScatter() const override { return 32; }
+  int getPackedFormatBitSize() const override { return 16; }
+  std::optional<int> getPackedFormatBitSizeDpasB() const override { return 32; }
+};
+
+//===----------------------------------------------------------------------===//
+// uArch instructions
+//===----------------------------------------------------------------------===//
+struct StoreNdInstruction : public Instruction {
+  StoreNdInstruction()
+      : Instruction(InstructionKind::STORE_ND, InstructionScope::Subgroup) {}
+
+  // Source :
+  // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroups.html#_add_a_new_section_6_13_x_sub_group_read_and_write_functions
+  // Reads 1, 2, 4, or 8 uints of data for each work item in the sub-group from
+  // the specified pointer
+  llvm::SmallVector<int> getSortedLaneVectorLengths() { return {1, 2, 4, 8}; }
+};
+
+struct LoadNdInstruction : public Instruction {
+  LoadNdInstruction()
+      : Instruction(InstructionKind::LOAD_ND, InstructionScope::Subgroup) {}
+
+  // Source :
+  // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroups.html#_add_a_new_section_6_13_x_sub_group_read_and_write_functions
+  // Writes 1, 2, 4, or 8 uints of data for each work item in the sub-group to
+  // the specified pointer.
+  llvm::SmallVector<int> getSortedLaneVectorLengths() { return {1, 2, 4, 8}; }
+};
+
+struct PrefetchNdInstruction : public Instruction {
+  PrefetchNdInstruction()
+      : Instruction(InstructionKind::PREFETCH_ND, InstructionScope::Subgroup) {}
+
+  // Source :
+  // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_buffer_prefetch.html#_add_a_new_section_6_15_x_sub_group_prefetch_functions
+  llvm::SmallVector<int> getSortedLaneVectorLengths(int elementBitwidth) {
+    if (elementBitwidth == 8 || elementBitwidth == 16)
+      return {1, 2, 4, 8, 16};
+    else if (elementBitwidth == 32 || elementBitwidth == 64)
+      return {1, 2, 4, 8};
+    else
+      llvm_unreachable(
+          "Unsupported element bitwidth for PrefetchNdInstruction");
+  }
 };
 
-// struct to represent DPAS instruction
 struct DPASInstruction : public Instruction, public MMAInstructionInterface {
   DPASInstruction()
       : Instruction(InstructionKind::DPAS, InstructionScope::Subgroup) {}
+  // Source:
+  // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_matrix_multiply_accumulate.html
 
   // Override all virtuals from MatrixOpInterface
   virtual llvm::SmallVector<std::pair<uint32_t, uint32_t>, 16>
@@ -72,6 +119,9 @@ struct DPASInstruction : public Instruction, public MMAInstructionInterface {
   virtual llvm::SmallVector<uint32_t, 8> getSupportedN(Type type) override;
 };
 
+//===----------------------------------------------------------------------===//
+// uArch instructions
+//===----------------------------------------------------------------------===//
 struct PVCuArch : public Xe2Plus {
   // Maintaines ownership of the instructions owned by PVUarch
   llvm::SmallVector<std::shared_ptr<Instruction>, 8> owned_instructions;
@@ -101,9 +151,15 @@ struct PVCuArch : public Xe2Plus {
         CacheInfo(512 * 1024, 64, CacheHierarchyLevel::L2));
 
     // Add the instructions-
-    auto dpas = std::make_shared<DPASInstruction>();
-    instructions.emplace(dpas->getInstructionKind(), dpas);
-    owned_instructions.push_back(dpas);
+    llvm::SmallVector<std::shared_ptr<Instruction>> instructionsToAdd{
+        std::make_shared<DPASInstruction>(),
+        std::make_shared<StoreNdInstruction>(),
+        std::make_shared<LoadNdInstruction>(),
+        std::make_shared<PrefetchNdInstruction>()};
+    for (auto &inst : instructionsToAdd) {
+      instructions.emplace(inst->getInstructionKind(), inst);
+      owned_instructions.push_back(inst);
+    }
   }
 };
 
@@ -139,10 +195,24 @@ struct BMGuArch : public Xe2Plus {
     owned_instructions.push_back(dpas);
   }
 };
+
+inline std::shared_ptr<uArch> getUArch(const std::string &archName) {
+  if (archName == "pvc")
+    return std::make_shared<PVCuArch>();
+  else if (archName == "bmg")
+    return std::make_shared<BMGuArch>();
+  else
+    return nullptr;
+}
+
 } // namespace uArch
 } // namespace xegpu
 } // namespace mlir
 
+//===----------------------------------------------------------------------===//
+// Instruction implementations
+//===----------------------------------------------------------------------===//
+
 inline llvm::SmallVector<std::pair<uint32_t, uint32_t>, 16>
 DPASInstruction::getSupportedShapes(Type dataType, MMAOpndKind matrixType) {
   auto combineVectors = [](const llvm::SmallVector<uint32_t, 8> &a,
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
index 955994ea5ecf5..0f5b1282f0e24 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
@@ -32,8 +32,11 @@ namespace uArch {
 // An enum class to represent the scope of an instruction
 enum class InstructionScope { Lane, Subgroup, Workgroup, Cluster };
 enum class InstructionKind {
-  DPAS, // Dot Product Accumulate Systolic (DPAS) is a matrix
-        // multiply-add operation
+  DPAS,       // Dot Product Accumulate Systolic (DPAS) is a matrix
+              // multiply-add operation
+  STORE_ND,   // Subgroup-level 2D block write instruction
+  LOAD_ND,    // Subgroup-level 2D block load instruction
+  PREFETCH_ND // Subgroup-level 2D block prefetch instruction
   // @TODO: Add more instructions as needed
 };
 
@@ -148,6 +151,16 @@ struct uArch {
 
   const std::string &getDescription() const { return description; }
 
+  virtual int getSubgroupSize() const = 0;
+  virtual int getPackedFormatBitSizeGatherScatter() const = 0;
+  virtual int getPackedFormatBitSize() const = 0;
+  virtual std::optional<int> getPackedFormatBitSizeDpasB() const = 0;
+
+  std::shared_ptr<Instruction> getInstruction(InstructionKind instKind) const {
+    assert(instructions.find(instKind) != instructions.end());
+    return instructions.at(instKind);
+  }
+
   const std::map<RegisterFileType, RegisterFileInfo> &
   getRegisterFileInfo() const {
     return registerFileInfo;
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 9beb22d517473..afda04fa71105 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -11,7 +11,7 @@
 #include "mlir/Dialect/Index/IR/IndexOps.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
-#include "mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
 #include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/DialectImplementation.h"
@@ -226,8 +226,10 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
   }
 
   if (inst_data && lane_layout && inst_data.size() != lane_layout.size()) {
-    return emitError()
-           << "expected inst_data and lane_layout to have the same rank";
+    return emitError() << "expected inst_data and lane_layout to have the same "
+                          "rank, got inst_data "
+                       << inst_data.size() << ", lane_layout "
+                       << lane_layout.size();
   }
 
   // sg_data is optional for Workgroup layout, but its presence requires
@@ -565,10 +567,10 @@ TensorDescType::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
 
   // for gather and scatter ops, Low-precision types are packed in 32-bit units.
   unsigned bitWidth = elementType.getIntOrFloatBitWidth();
-  int chunkAlignmentFactor =
-      bitWidth < targetinfo::packedSizeInBitsForGatherScatter
-          ? targetinfo::packedSizeInBitsForGatherScatter / bitWidth
-          : 1;
+  constexpr int packingBitSizeGatherScatter{32};
+  int chunkAlignmentFactor = bitWidth < packingBitSizeGatherScatter
+                                 ? packingBitSizeGatherScatter / bitWidth
+                                 : 1;
   auto scatterAttr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(encoding);
   if (scatterAttr) {
     int64_t chunkSize = scatterAttr.getChunkSizeAsInt();
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 8fab255d6347f..9c09908f3547d 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -14,7 +14,6 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
-#include "mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h"
 #include "mlir/Dialect/XeGPU/Transforms/Passes.h"
 #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
 #include "mlir/IR/Attributes.h"
@@ -37,6 +36,8 @@
 #include "llvm/Support/LogicalResult.h"
 #include "llvm/Support/raw_ostream.h"
 
+#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
+
 namespace mlir {
 namespace xegpu {
 #define GEN_PASS_DEF_XEGPUPROPAGATELAYOUT
@@ -104,6 +105,8 @@ struct LayoutInfo {
 
   SmallVector<int> getLaneData() const;
 
+  SmallVector<int> getInstData() const;
+
   bool isSliceLayout() const {
     if (!isAssigned())
       return false;
@@ -137,6 +140,13 @@ SmallVector<int> LayoutInfo::getLaneData() const {
                              [](int64_t val) { return static_cast<int>(val); });
 }
 
+SmallVector<int> LayoutInfo::getInstData() const {
+  if (!isAssigned())
+    return {};
+  return llvm::map_to_vector(storage.getEffectiveInstDataAsInt(),
+                             [](int64_t val) { return static_cast<int>(val); });
+}
+
 void LayoutInfo::print(raw_ostream &os) const {
   if (isAssigned()) {
     os << storage;
@@ -174,12 +184,14 @@ LayoutInfo LayoutInfo::transpose(ArrayRef<int64_t> permutation) const {
 
   SmallVector<int32_t> laneLayout;
   SmallVector<int32_t> laneData;
+  SmallVector<int32_t> instData;
   for (int64_t idx : permutation) {
     laneLayout.push_back(static_cast<int32_t>(getLaneLayout()[idx]));
     laneData.push_back(static_cast<int32_t>(getLaneData()[idx]));
+    instData.push_back(static_cast<int32_t>(getInstData()[idx]));
   }
-  return LayoutInfo(
-      xegpu::LayoutAttr::get(storage.getContext(), laneLayout, laneData));
+  return LayoutInfo(xegpu::LayoutAttr::get(storage.getContext(), instData,
+                                           laneLayout, laneData));
 }
 
 //===----------------------------------------------------------------------===//
@@ -199,20 +211,33 @@ struct LayoutInfoLattice : public Lattice<LayoutInfo> {
 /// Helper Function to get the default layout for uniform values like constants.
 /// For 1D vector, lane_layout is [subgroupSize] and lane_data is [1].
 /// For 2D vector, lane_layout is [1, subgroupSize] and lane_data is [1, 1].
-static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
-                                           unsigned rank) {
+static LayoutInfo
+getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx, unsigned rank,
+                         std::shared_ptr<xegpu::uArch::uArch> &uArch,
+                         ArrayRef<int> instData) {
   assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
   if (rank == 1) {
     return LayoutInfo(
-        xegpu::LayoutAttr::get(ctx, {xegpu::targetinfo::subgroupSize}, {1}));
+        xegpu::LayoutAttr::get(ctx, instData, {uArch->getSubgroupSize()}, {1}));
   }
   return LayoutInfo(xegpu::LayoutAttr::get(
-      ctx, {1, xegpu::targetinfo::subgroupSize}, {1, 1}));
+      ctx, instData, {1, uArch->getSubgroupSize()}, {1, 1}));
+}
+
+static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
+                                           unsigned rank, int subgroupSize) {
+  assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
+  if (rank == 1) {
+    return LayoutInfo(xegpu::LayoutAttr::get(ctx, {subgroupSize}, {1}));
+  }
+  return LayoutInfo(xegpu::LayoutAttr::get(ctx, {1, subgroupSize}, {1, 1}));
 }
 
 /// Helper to get the default layout for a vector type.
-static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy,
-                                           bool isScattered = false) {
+static LayoutInfo
+getDefaultSIMTLayoutInfo(VectorType vectorTy,
+                         std::shared_ptr<xegpu::uArch::uArch> &uArch,
+                         ArrayRef<int> instData, bool isScattered = false) {
   // Expecting a 1D or 2D vector.
   assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
          "Expected 1D or 2D vector.");
@@ -221,29 +246,31 @@ static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy,
          "Expected int or float element type.");
   // If the rank is 1, then return default layout for 1D vector.
   if (vectorTy.getRank() == 1)
-    return getDefaultSIMTLayoutInfo(vectorTy.getContext(), 1);
+    return getDefaultSIMTLayoutInfo(vectorTy.getContext(), 1, uArch, instData);
   // Packing factor is determined by the element type bitwidth.
   int packingFactor = 1;
   unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
   if (isScattered) {
     packingFactor =
-        bitwidth < xegpu::targetinfo::packedSizeInBitsForGatherScatter
-            ? xegpu::targetinfo::packedSizeInBitsForGatherScatter / bitwidth
+        bitwidth < uArch->getPackedFormatBitSizeGatherScatter()
+            ? uArch->getPackedFormatBitSizeGatherScatter() / bitwidth
             : 1;
-    return LayoutInfo(xegpu::LayoutAttr::get(
-        vectorTy.getContext(), {xegpu::targetinfo::subgroupSize, 1},
-        {1, packingFactor}));
+    return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(), instData,
+                                             {uArch->getSubgroupSize(), 1},
+                                             {1, packingFactor}));
   }
-  if (bitwidth < xegpu::targetinfo::packedSizeInBitsForDefault)
-    packingFactor = xegpu::targetinfo::packedSizeInBitsForDefault / bitwidth;
-  return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(),
-                                           {1, xegpu::targetinfo::subgroupSize},
+  if (bitwidth < uArch->getPackedFormatBitSize())
+    packingFactor = uArch->getPackedFormatBitSize() / bitwidth;
+  return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(), instData,
+                                           {1, uArch->getSubgroupSize()},
                                            {1, packingFactor}));
 }
 
 /// Helper to get the default layout for a vector type.
-static LayoutInfo get...
[truncated]

@akroviakov akroviakov requested a review from mshahneo October 16, 2025 14:56
@akroviakov akroviakov force-pushed the akroviak/xegpu-enhance-layout-prop branch 2 times, most recently from 63d11b8 to 4b99cdd Compare October 16, 2025 17:00
auto dpas = std::make_shared<DPASInstruction>();
instructions.emplace(dpas->getInstructionKind(), dpas);
owned_instructions.push_back(dpas);
llvm::SmallVector<std::shared_ptr<Instruction>> instructionsToAdd{

nit - formatting

@@ -42,12 +40,61 @@ struct Xe2Plus : public uArch {
&instrs = {})
: uArch(archName, archDescription, regInfo, cacheInfo, instrs),
xeCore(xeCore) {}
int getSubgroupSize() const override { return 16; }
unsigned getPackedFormatBitSizeGatherScatter() const override { return 32; }
unsigned getPackedFormatBitSize() const override { return 16; }
Contributor

is getPackedFormatBitSize really getPackedFormatBitSizeDpasA?

Contributor Author

Yes, will be renamed. And I think it should be a member of dpas instruction per uarch instance. We might want to split this PR into two parts to have a substantial discussion in each: (1) uArch modification and (2) propagation option and uArch application in passes.

Contributor Author

For C, is it the same as B (32)?

Contributor

For C, the result is f32 so no packing needed.

Contributor Author

For a generic lane data calculation for dpas operands, wouldn't the following format be desired in the dpas propagation
packingFactor = dpasInst->getOperand*A/B/C*PackingBitSize() / dataElemBitwidth?

It is not so much about whether we actually consider "packing" C.
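A minimal sketch of the generic per-operand calculation being asked about, assuming each operand simply reports a packed bit size. `packingFactor` is a hypothetical helper, not an existing uArch API. The sizes 16 (A) and 32 (B) come from the diff; 32 for C is only the value floated in the question above.

```cpp
#include <cassert>

// Number of elements that share one packed unit. Elements at least as wide
// as the packed unit are not packed, so the factor is 1.
inline int packingFactor(int packedBitSize, int elemBitWidth) {
  assert(packedBitSize > 0 && elemBitWidth > 0 && "bit sizes must be positive");
  return elemBitWidth < packedBitSize ? packedBitSize / elemBitWidth : 1;
}
```

For f16 this gives 1 for A (16 / 16) and 2 for B (32 / 16); for an f32 C operand with a 32-bit packed size it gives 1, which agrees with "no packing needed" without special-casing C.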

Contributor

agree that it should be part of DPAS.

Contributor

@charithaintc left a comment

Overall direction looks great. Thanks for the excellent work!

Will do another round after your feedback.

bitWidth < targetinfo::packedSizeInBitsForGatherScatter
? targetinfo::packedSizeInBitsForGatherScatter / bitWidth
: 1;
constexpr int packingBitSizeGatherScatter{32};
Contributor

Still having hard-coded numbers defeats the purpose of the uArch interface. I think we have 2 alternatives here.

  1. Relax the IR and check these things during lowering when uArch is available.
  2. Define a base class implementation for certain things. I believe packing size for gather/scatter is common for all archs.

auto tdescTy = prefetch.getTensorDescType();
auto prefetchLayout = getDefaultSIMTLayoutInfo(tdescTy);

auto uArch = getUArch(getChipStr(prefetch).value_or(""));
Contributor

I think rather than defaulting to a certain arch, it's better to fail the propagation early on if the module lacks the arch info.

Assuming an arch can have unintended effects for the user. We should let them know as early as possible that they are doing something wrong by not providing us an arch.

Contributor Author

Added llvm_unreachable to getUArch(). If we call the function, then we definitely care about the returned object in any place of the codebase, and it must be a valid/known uarch.
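One way to picture the "fail early" option as a recoverable check rather than an unreachable: validate the chip string once, before any uArch query, and let the pass report the specific problem. This is only a sketch; `ChipStatus` and `validateChip` are hypothetical, the empty string stands for a missing chip (mirroring the `value_or("")` call in the snippet above), and the accepted names are the two that getUArch() recognizes in this patch.

```cpp
#include <cassert>
#include <string>

// Hypothetical result of validating the chip string attached to a GPU module.
enum class ChipStatus { Ok, MissingChip, UnknownChip };

// An empty string models "no chip attribute present".
inline ChipStatus validateChip(const std::string &chip) {
  if (chip.empty())
    return ChipStatus::MissingChip;
  if (chip == "pvc" || chip == "bmg")
    return ChipStatus::Ok;
  return ChipStatus::UnknownChip;
}
```

A pass could run such a check up front and emit a diagnostic for the two failure cases, so that a later getUArch() call can legitimately treat an unknown name as unreachable.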

if (tdescTy.getRank() == 1)
instData = {subgroupSize};
else
instData = {maxVecLength, subgroupSize};
Contributor

here shouldn't we take the min of tile size and uArch max size?
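A sketch of the clamping being suggested, assuming a 1D or 2D tile and the same `maxVecLength` and `subgroupSize` values as in the snippet above. `clampInstData` is a hypothetical helper; clamping the innermost dimension to the subgroup size as well is an assumption of this sketch, not something stated in the comment.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Picks inst_data no larger than the tile itself, per dimension.
inline std::vector<int> clampInstData(const std::vector<int> &tileShape,
                                      int maxVecLength, int subgroupSize) {
  assert((tileShape.size() == 1 || tileShape.size() == 2) &&
         "Expected a 1D or 2D tile.");
  if (tileShape.size() == 1)
    return {std::min(tileShape[0], subgroupSize)};
  return {std::min(tileShape[0], maxVecLength),
          std::min(tileShape[1], subgroupSize)};
}
```

A plain min can still produce a row count the instruction does not support (for example 3), so a real implementation would probably pick the largest supported length that does not exceed the tile.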

if (store.getValueType().getRank() == 1)
instData = {subgroupSize};
else
instData = {maxVecLength, subgroupSize};
Contributor

same comment as above. need to check against the provided store shape.

@akroviakov akroviakov force-pushed the akroviak/xegpu-enhance-layout-prop branch from 4b99cdd to 62be14d Compare October 28, 2025 09:14

5 participants